# -*- Python -*-
# bi_v1 with encoder and decoder parameters shared
include 'models/bi_v1.gin'

# Layers for the layer stack built under the `encoder/` scope, listed in
# order: a self-attention layer, then a dense-relu-dense layer. Each entry is
# a (name, layer class) pair rather than a bare class reference.
# NOTE(review): the explicit names presumably exist so these layers get the
# same names as the corresponding decoder layers, which the parameter sharing
# described in the file header would need -- confirm against make_layer_stack.
encoder/transformer.make_layer_stack.layers = [
    ('self_attention', @mesh_tensorflow.transformer.transformer_layers.SelfAttention),
    ('dense_relu_dense', @mesh_tensorflow.transformer.transformer_layers.DenseReluDense),
]
# Layers for the layer stack built under the `decoder/` scope, listed in
# order: self-attention, encoder-decoder attention, then dense-relu-dense.
# Each entry is a (name, layer class) pair. The 'self_attention' and
# 'dense_relu_dense' names are the same strings used in the encoder's list;
# 'enc_dec_attention' appears only here.
# NOTE(review): the matching names are presumably what lets the two stacks
# share parameters for those layers -- confirm against make_layer_stack.
decoder/transformer.make_layer_stack.layers = [
    ('self_attention', @mesh_tensorflow.transformer.transformer_layers.SelfAttention),
    ('enc_dec_attention', @mesh_tensorflow.transformer.transformer_layers.EncDecAttention),
    ('dense_relu_dense', @mesh_tensorflow.transformer.transformer_layers.DenseReluDense),
]

# Deliberately give the encoder the name "decoder" (this is not a typo).
# NOTE(review): presumably this places the encoder's variables under the same
# name as the decoder's, so the two share parameters as described in the file
# header -- confirm against make_bitransformer.
make_bitransformer.encoder_name = "decoder"
